import os
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed, AutoConfig
import argparse
import torch
import warnings
import pynvml
from tqdm import tqdm
from safetensors.torch import save_file
# Initialize NVML once at import time; logging_cuda_memory_usage() relies on it.
pynvml.nvmlInit()
import logging
# NOTE(review): transformers (and with it huggingface_hub) is already imported
# above, and those libraries appear to resolve HF_HOME when they are imported,
# so this assignment likely has no effect on the cache location -- confirm
# intent; it would need to run before the transformers import to take effect.
os.environ['HF_HOME'] = ''
# Suppress every Python warning for the remainder of the process.
warnings.simplefilter("ignore")
# Expose only physical GPU 0 to CUDA. Presumably still effective after
# `import torch` because CUDA is initialized lazily on first use.
# NOTE(review): NVML may enumerate all GPUs regardless of this variable, so
# logging_cuda_memory_usage() could report more devices than CUDA sees.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


# Long-form system prompt (two paragraphs: behavior rules, then guidance on
# nonsensical / unknown questions).
# NOTE(review): not referenced by the code in this script; prepend_sys_prompt()
# builds user-only conversations.
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""


# Abbreviated variant: exactly the first two sentences of DEFAULT_SYSTEM_PROMPT.
# NOTE(review): not referenced by the code in this script.
SHORT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe."""


# Paraphrased safety prompt; the name suggests it is meant for Mistral models.
# NOTE(review): not referenced by the code in this script.
MISTRAL_SYSTEM_PROMPT = """You are an AI assistant. Always assist with care, respect, and truth. Please respond with utmost utility yet securely and avoid harmful, unethical, prejudiced, or negative content. Also ensure replies promote fairness and positivity."""



def logging_cuda_memory_usage():
    """Emit one INFO log line per NVML-visible GPU with used / total memory in GiB.

    Assumes ``pynvml.nvmlInit()`` has already been called (this module does so
    at import time). Output goes through the root ``logging`` logger.
    """
    bytes_per_gib = 1024 ** 3
    device_count = pynvml.nvmlDeviceGetCount()
    for device_index in range(device_count):
        mem = pynvml.nvmlDeviceGetMemoryInfo(
            pynvml.nvmlDeviceGetHandleByIndex(device_index)
        )
        used_gib = mem.used / bytes_per_gib
        total_gib = mem.total / bytes_per_gib
        logging.info("GPU {}: {:.2f} GB / {:.2f} GB".format(device_index, used_gib, total_gib))

def prepend_sys_prompt(sentence, system_prompt=None):
    """Wrap *sentence* as a single-turn chat conversation.

    Args:
        sentence: The user query; leading/trailing whitespace is stripped.
        system_prompt: Optional system prompt. When non-empty, a
            ``{'role': 'system', ...}`` message is placed before the user turn.
            The default ``None`` yields a user-only conversation, which is the
            original behavior of this function despite its name.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts, in the shape that
        ``forward`` passes to ``tokenizer.apply_chat_template``.
    """
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': sentence.strip()})
    return messages

def forward(model, toker, messages):
    """Run one chat conversation through *model* and return per-layer hidden states.

    Args:
        model: Causal LM called as ``model(input_ids, attention_mask=...,
            return_dict=True, output_hidden_states=True)``. It must expose
            ``.device`` and return an object with a ``.hidden_states`` sequence.
        toker: Tokenizer providing ``apply_chat_template``, ``tokenize`` and
            ``convert_tokens_to_ids``.
        messages: Chat messages (list of role/content dicts).

    Returns:
        A list with one tensor per entry of ``outputs.hidden_states`` except
        the first. Each tensor is batch element 0 of that entry, detached,
        cast to float16 and moved to CPU.
    """
    # Render the template to text, then tokenize it. NOTE(review): HF's
    # ``tokenize`` does not add BOS/EOS by default, so special tokens
    # presumably come only from the rendered template (avoiding a double BOS).
    input_text = toker.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    input_ids = torch.tensor(
        toker.convert_tokens_to_ids(toker.tokenize(input_text)),
        dtype=torch.long,
    ).unsqueeze(0).to(model.device)

    # Pure feature extraction: disabling autograd avoids keeping activations
    # for a backward pass that never happens, which lowers peak memory.
    with torch.no_grad():
        outputs = model(
            input_ids,
            # Single unpadded sequence, so every position is attended. Use an
            # integer mask (the conventional HF dtype) rather than a float
            # mask in the model's dtype.
            attention_mask=torch.ones_like(input_ids),
            return_dict=True,
            output_hidden_states=True,
        )

    # Skip entry 0 (in HF models this is the embedding output) and keep only
    # the single batch row of each remaining layer.
    return [layer[0].detach().half().cpu() for layer in outputs.hidden_states[1:]]



def _load_questions(path):
    """Return the ``'question'`` field of every record in the JSON list at *path*."""
    with open(path, 'r') as f:
        return [el['question'] for el in json.load(f)]


def _load_chat_template(path):
    """Read a Jinja chat template from *path* and collapse it onto one line.

    Every run of four spaces and every newline is removed. Presumably this lets
    the template file be indented for readability without injecting that
    whitespace into prompts.
    """
    with open(path) as f:
        return f.read().replace('    ', '').replace('\n', '')


def _collect_hidden_states(model, toker, queries):
    """Run every conversation in *queries* and key the results by sample and layer.

    Returns a dict mapping ``'sample.{idx}_layer.{i}'`` to the hidden-state
    tensor that ``forward`` returned for layer ``i`` of query ``idx``.
    """
    tensors = {}
    for idx, messages in tqdm(enumerate(queries), total=len(queries), dynamic_ncols=True):
        for i, hs in enumerate(forward(model, toker, messages)):
            tensors[f'sample.{idx}_layer.{i}'] = hs
    return tensors


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_paths", type=str, required=True)
    # The default keeps the historical output location; override it when the
    # model is not llama2-7b-chat so files are not mislabeled or overwritten.
    parser.add_argument(
        "--output_prefix",
        type=str,
        default='tensor/llama2-7b-chat',
        help="Outputs go to '<prefix>_answerable.safetensors' and "
             "'<prefix>_unanswerable.safetensors'.",
    )
    args = parser.parse_args()

    # Without a handler/level the root logger stays at WARNING and every
    # logging.info() call in this script is silently dropped.
    logging.basicConfig(level=logging.INFO)

    # NOTE(review): no torch_dtype / device_map / attn_implementation is
    # passed, so from_pretrained's defaults apply (no device_map => the model
    # is not placed on a GPU). Confirm this is intended.
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained_model_paths,
        use_safetensors=True,
    )
    toker = AutoTokenizer.from_pretrained(args.pretrained_model_paths)

    # The generation config is looked up by the model path string verbatim.
    generation_config_file = 'generation_configs/{}.json'.format(args.pretrained_model_paths)
    with open(generation_config_file) as f:
        generation_config = json.load(f)
    toker.chat_template = _load_chat_template(generation_config['chat_template'])

    all_answerable_queries = [
        prepend_sys_prompt(q) for q in _load_questions('dataset/100-answerable.json')
    ]
    all_unanswerable_queries = [
        prepend_sys_prompt(q) for q in _load_questions('dataset/100-unanswerable.json')
    ]

    # save_file cannot create missing parent directories; create them now
    # rather than failing after all the (slow) forward passes.
    os.makedirs(os.path.dirname(args.output_prefix) or '.', exist_ok=True)

    logging.info("Running")

    answerable_tensors = _collect_hidden_states(model, toker, all_answerable_queries)
    save_file(answerable_tensors, f'{args.output_prefix}_answerable.safetensors')
    del answerable_tensors  # release before the second pass

    unanswerable_tensors = _collect_hidden_states(model, toker, all_unanswerable_queries)
    save_file(unanswerable_tensors, f'{args.output_prefix}_unanswerable.safetensors')